{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "isp7wiMgfXik"
      },
      "outputs": [],
      "source": [
        "from datasets import load_dataset\n",
        "\n",
        "# Three fixed IMDB subsets, each built from a head slice plus a tail slice\n",
        "# of one split. The validation slices are taken from `test` but use index\n",
        "# ranges distinct from the test subset's (given the split has >= 2000 rows).\n",
        "imdb_train, imdb_test, imdb_val = (\n",
        "    load_dataset(\"imdb\", split=split_spec)\n",
        "    for split_spec in (\n",
        "        \"train[:1000]+train[-1000:]\",\n",
        "        \"test[:500]+test[-500:]\",\n",
        "        \"test[500:1000]+test[-1000:-500]\",\n",
        "    )\n",
        ")\n",
        "\n",
        "# Last expression of the cell: display the shape of each subset.\n",
        "tuple(subset.shape for subset in (imdb_train, imdb_test, imdb_val))\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import BertTokenizerFast, BertForSequenceClassification\n",
        "\n",
        "model_path = \"bert-base-uncased\"\n",
        "\n",
        "tokenizer = BertTokenizerFast.from_pretrained(model_path)\n",
        "\n",
        "\n",
        "def tokenize_it(batch):\n",
        "    \"\"\"Tokenize one batch of examples for `Dataset.map(..., batched=True)`.\n",
        "\n",
        "    Args:\n",
        "        batch: Mapping of column name to a list of values. The raw review\n",
        "            strings are read from the \"text\" column (assumed IMDB schema\n",
        "            -- TODO confirm).\n",
        "\n",
        "    Returns:\n",
        "        The tokenizer output for the batch; `map` stores its fields as\n",
        "        additional columns on the dataset.\n",
        "    \"\"\"\n",
        "    # Truncate over-long reviews and pad every row to the same maximum length,\n",
        "    # so rows coming from different map batches all share one shape.\n",
        "    return tokenizer(batch[\"text\"], padding=\"max_length\", truncation=True)\n",
        "\n",
        "\n",
        "# Encode each subset with the same function and batch size.\n",
        "enc_train = imdb_train.map(tokenize_it, batched=True, batch_size=1000)\n",
        "\n",
        "enc_test = imdb_test.map(tokenize_it, batched=True, batch_size=1000)\n",
        "\n",
        "enc_val = imdb_val.map(tokenize_it, batched=True, batch_size=1000)\n"
      ],
      "metadata": {
        "id": "MZoNMut6ffTp"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def hp_space(trial):\n",
        "    \"\"\"Describe the hyperparameter search space for a single trial.\n",
        "\n",
        "    Args:\n",
        "        trial: Object exposing `suggest_float` and `suggest_categorical`\n",
        "            (e.g. an Optuna trial -- TODO confirm the search backend).\n",
        "\n",
        "    Returns:\n",
        "        dict with the suggested `learning_rate` (between 1e-6 and 1e-4,\n",
        "        requested with `log=True`) and `per_device_train_batch_size`\n",
        "        (one of 8 or 16).\n",
        "    \"\"\"\n",
        "    hp = {\n",
        "        \"learning_rate\": trial.suggest_float(\"learning_rate\", 1e-6, 1e-4, log=True),\n",
        "        \"per_device_train_batch_size\": trial.suggest_categorical(\n",
        "            \"per_device_train_batch_size\", [8, 16]\n",
        "        ),\n",
        "    }\n",
        "\n",
        "    return hp\n"
      ],
      "metadata": {
        "id": "lSeHdMIwflcn"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Launch the search over `hp_space` with direction=\"maximize\".\n",
        "# NOTE(review): `trainer` is not defined in the cells shown here -- it must\n",
        "# be created before this cell runs.\n",
        "best_run = trainer.hyperparameter_search(\n",
        "    hp_space=hp_space,\n",
        "    direction=\"maximize\",\n",
        ")"
      ],
      "metadata": {
        "id": "B8X0GYHSf11d"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}